Optimizing the Text Generation Model

You've already done some amazing work with generating new songs, but so far we've seen some issues with repetition and a fair amount of incoherence. By using more data and further tweaking the model, you'll be able to get improved results. We'll once again use the Kaggle Song Lyrics Dataset here.

In [1]:
import tensorflow as tf

from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

# Other imports for processing data
import string
import numpy as np
import pandas as pd

Get the Dataset

As noted above, we'll utilize the Song Lyrics dataset on Kaggle again.

In [2]:
!wget --no-check-certificate \
    https://drive.google.com/uc?id=1LiJFZd41ofrWoBtW-pMYsfz1w8Ny0Bj8 \
    -O /tmp/songdata.csv
--2020-08-09 03:56:43--  https://drive.google.com/uc?id=1LiJFZd41ofrWoBtW-pMYsfz1w8Ny0Bj8
Resolving drive.google.com (drive.google.com)... 172.217.212.100, 172.217.212.139, 172.217.212.113, ...
Connecting to drive.google.com (drive.google.com)|172.217.212.100|:443... connected.
HTTP request sent, awaiting response... 302 Moved Temporarily
Location: https://doc-04-ak-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/gquoughp2j9j9686dukdjcn69i8sp64b/1596945375000/11118900490791463723/*/1LiJFZd41ofrWoBtW-pMYsfz1w8Ny0Bj8 [following]
Warning: wildcards not supported in HTTP.
--2020-08-09 03:56:45--  https://doc-04-ak-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/gquoughp2j9j9686dukdjcn69i8sp64b/1596945375000/11118900490791463723/*/1LiJFZd41ofrWoBtW-pMYsfz1w8Ny0Bj8
Resolving doc-04-ak-docs.googleusercontent.com (doc-04-ak-docs.googleusercontent.com)... 173.194.198.132, 2607:f8b0:4001:c1c::84
Connecting to doc-04-ak-docs.googleusercontent.com (doc-04-ak-docs.googleusercontent.com)|173.194.198.132|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [text/csv]
Saving to: ‘/tmp/songdata.csv’

/tmp/songdata.csv       [   <=>              ]  69.08M   136MB/s    in 0.5s    

2020-08-09 03:56:46 (136 MB/s) - ‘/tmp/songdata.csv’ saved [72436445]

250 Songs

We've now seen a model trained on just a small sample of songs, and how this often leads to repetition the further you get into generating new text. Let's switch to using 250 songs instead and see if the output improves. That works out to nearly 10K lines of lyrics, which should be sufficient.

Note that we won't use the full dataset here, as it would take quite a bit of RAM and processing time, but you're welcome to try it on your own later. If you do, you'll likely want to limit the Tokenizer to only the most common words, which shrinks both processing time and memory (otherwise the one-hot label array would be hundreds of thousands of columns wide). A rough sketch of this appears just below.
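
As a minimal, hedged sketch of what that might look like on the full dataset (the 5000-word cap is an illustrative assumption, not a tuned setting, and it reuses the create_lyrics_corpus and tokenize_corpus helpers defined in the Preprocessing section below):

# Hypothetical: tokenize the FULL dataset with a capped vocabulary.
# The num_words value is an illustrative assumption, not a tuned setting.
full_dataset = pd.read_csv('/tmp/songdata.csv', dtype=str)      # all songs in the CSV
full_corpus = create_lyrics_corpus(full_dataset, 'text')        # helper defined below
full_tokenizer = tokenize_corpus(full_corpus, num_words=5000)   # keep only the 5000 most common words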

Preprocessing

In [3]:
def tokenize_corpus(corpus, num_words=-1):
  # Fit a Tokenizer on the corpus
  if num_words > -1:
    tokenizer = Tokenizer(num_words=num_words)
  else:
    tokenizer = Tokenizer()
  tokenizer.fit_on_texts(corpus)
  return tokenizer

def create_lyrics_corpus(dataset, field):
  # Remove all punctuation (the pattern is treated as a regex character class)
  dataset[field] = dataset[field].str.replace('[{}]'.format(string.punctuation), '', regex=True)
  # Make it lowercase
  dataset[field] = dataset[field].str.lower()
  # Make it one long string to split by line
  lyrics = dataset[field].str.cat()
  corpus = lyrics.split('\n')
  # Remove any trailing whitespace
  for l in range(len(corpus)):
    corpus[l] = corpus[l].rstrip()
  # Remove any empty lines
  corpus = [l for l in corpus if l != '']

  return corpus
In [4]:
# Read the dataset from csv - this time with 250 songs
dataset = pd.read_csv('/tmp/songdata.csv', dtype=str)[:250]
# Create the corpus using the 'text' column containing lyrics
corpus = create_lyrics_corpus(dataset, 'text')
# Tokenize the corpus
tokenizer = tokenize_corpus(corpus, num_words=2000)
total_words = tokenizer.num_words

# This prints the num_words cap set above, not the full vocabulary size
print(total_words)
2000
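
The Tokenizer still fits on every word it sees and only applies the num_words cap when converting text to sequences, so the underlying vocabulary is much larger than 2000. A quick check, assuming the tokenizer created above:

# word_index always contains the complete fitted vocabulary, regardless of num_words
print(len(tokenizer.word_index))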

Create Sequences and Labels

In [5]:
sequences = []
for line in corpus:
  # Convert the line to token ids, then add every prefix of length >= 2 as a training sequence
  token_list = tokenizer.texts_to_sequences([line])[0]
  for i in range(1, len(token_list)):
    n_gram_sequence = token_list[:i+1]
    sequences.append(n_gram_sequence)

# Pad sequences for equal input length 
max_sequence_len = max([len(seq) for seq in sequences])
sequences = np.array(pad_sequences(sequences, maxlen=max_sequence_len, padding='pre'))

# Split sequences between the "input" sequence and "output" predicted word
input_sequences, labels = sequences[:,:-1], sequences[:,-1]
# One-hot encode the labels
one_hot_labels = tf.keras.utils.to_categorical(labels, num_classes=total_words)
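
To see what one of these training examples actually contains, here's a small sketch that decodes a padded input sequence and its label back into words (the example index 5 is arbitrary; this assumes the tokenizer, input_sequences, and labels defined above):

# Decode one (input sequence, label) pair back into words
index_to_word = {index: word for word, index in tokenizer.word_index.items()}
sample_input, sample_label = input_sequences[5], labels[5]
print(sample_input)                                                  # leading zeros are padding
print([index_to_word.get(i, '?') for i in sample_input if i != 0])   # the input words
print(index_to_word.get(sample_label, '?'))                          # the word the model should predict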

Train a (Better) Text Generation Model

With more data, training takes longer, so we'll cut off after 100 epochs to avoid keeping you here all day. You'll also want to change your runtime type to GPU if you haven't already (you'll need to re-run the cells above if you change runtimes). If you'd rather not hard-code the epoch count, see the callback sketch below.
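
Here is a minimal sketch of letting Keras decide when to stop instead of fixing the epoch count; the patience value and checkpoint path are illustrative assumptions, not part of this notebook's run:

# Optional: stop once accuracy plateaus and keep the best weights on disk.
# The patience value and file path are illustrative assumptions.
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='accuracy', patience=5, restore_best_weights=True),
    tf.keras.callbacks.ModelCheckpoint('/tmp/lyrics_model.h5', monitor='accuracy', save_best_only=True),
]
# Then pass callbacks=callbacks to model.fit in the next cell.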

In [6]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense, Bidirectional

model = Sequential()
model.add(Embedding(total_words, 64, input_length=max_sequence_len-1))
model.add(Bidirectional(LSTM(20)))
model.add(Dense(total_words, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
history = model.fit(input_sequences, one_hot_labels, epochs=100, verbose=1)
Epoch 1/100
1480/1480 [==============================] - 22s 15ms/step - loss: 5.9793 - accuracy: 0.0464
Epoch 2/100
1480/1480 [==============================] - 22s 15ms/step - loss: 5.6729 - accuracy: 0.0517
Epoch 3/100
1480/1480 [==============================] - 22s 15ms/step - loss: 5.4353 - accuracy: 0.0706
Epoch 4/100
1480/1480 [==============================] - 22s 15ms/step - loss: 5.2396 - accuracy: 0.0959
Epoch 5/100
1480/1480 [==============================] - 22s 15ms/step - loss: 5.0736 - accuracy: 0.1138
Epoch 6/100
1480/1480 [==============================] - 22s 15ms/step - loss: 4.9118 - accuracy: 0.1312
Epoch 7/100
1480/1480 [==============================] - 22s 15ms/step - loss: 4.7640 - accuracy: 0.1475
Epoch 8/100
1480/1480 [==============================] - 22s 15ms/step - loss: 4.6310 - accuracy: 0.1603
Epoch 9/100
1480/1480 [==============================] - 22s 15ms/step - loss: 4.5116 - accuracy: 0.1749
Epoch 10/100
1480/1480 [==============================] - 22s 15ms/step - loss: 4.4031 - accuracy: 0.1877
Epoch 11/100
1480/1480 [==============================] - 22s 15ms/step - loss: 4.3040 - accuracy: 0.1982
Epoch 12/100
1480/1480 [==============================] - 22s 15ms/step - loss: 4.2072 - accuracy: 0.2111
Epoch 13/100
1480/1480 [==============================] - 22s 15ms/step - loss: 4.1223 - accuracy: 0.2216
Epoch 14/100
1480/1480 [==============================] - 22s 15ms/step - loss: 4.0432 - accuracy: 0.2312
Epoch 15/100
1480/1480 [==============================] - 22s 15ms/step - loss: 3.9682 - accuracy: 0.2407
Epoch 16/100
1480/1480 [==============================] - 22s 15ms/step - loss: 3.8968 - accuracy: 0.2493
Epoch 17/100
1480/1480 [==============================] - 22s 15ms/step - loss: 3.8372 - accuracy: 0.2578
Epoch 18/100
1480/1480 [==============================] - 21s 15ms/step - loss: 3.7767 - accuracy: 0.2662
Epoch 19/100
1480/1480 [==============================] - 22s 15ms/step - loss: 3.7177 - accuracy: 0.2765
Epoch 20/100
1480/1480 [==============================] - 22s 15ms/step - loss: 3.6649 - accuracy: 0.2837
Epoch 21/100
1480/1480 [==============================] - 22s 15ms/step - loss: 3.6100 - accuracy: 0.2915
Epoch 22/100
1480/1480 [==============================] - 21s 14ms/step - loss: 3.5613 - accuracy: 0.3002
Epoch 23/100
1480/1480 [==============================] - 21s 14ms/step - loss: 3.5162 - accuracy: 0.3058
Epoch 24/100
1480/1480 [==============================] - 21s 14ms/step - loss: 3.4769 - accuracy: 0.3110
Epoch 25/100
1480/1480 [==============================] - 21s 15ms/step - loss: 3.4364 - accuracy: 0.3165
Epoch 26/100
1480/1480 [==============================] - 21s 14ms/step - loss: 3.4019 - accuracy: 0.3218
Epoch 27/100
1480/1480 [==============================] - 21s 14ms/step - loss: 3.3586 - accuracy: 0.3303
Epoch 28/100
1480/1480 [==============================] - 21s 14ms/step - loss: 3.3278 - accuracy: 0.3357
Epoch 29/100
1480/1480 [==============================] - 21s 14ms/step - loss: 3.2918 - accuracy: 0.3390
Epoch 30/100
1480/1480 [==============================] - 21s 14ms/step - loss: 3.2596 - accuracy: 0.3451
Epoch 31/100
1480/1480 [==============================] - 21s 14ms/step - loss: 3.2322 - accuracy: 0.3488
Epoch 32/100
1480/1480 [==============================] - 21s 14ms/step - loss: 3.2007 - accuracy: 0.3525
Epoch 33/100
1480/1480 [==============================] - 22s 15ms/step - loss: 3.1768 - accuracy: 0.3567
Epoch 34/100
1480/1480 [==============================] - 21s 14ms/step - loss: 3.1389 - accuracy: 0.3624
Epoch 35/100
1480/1480 [==============================] - 21s 15ms/step - loss: 3.1257 - accuracy: 0.3661
Epoch 36/100
1480/1480 [==============================] - 22s 15ms/step - loss: 3.0943 - accuracy: 0.3721
Epoch 37/100
1480/1480 [==============================] - 22s 15ms/step - loss: 3.0838 - accuracy: 0.3738
Epoch 38/100
1480/1480 [==============================] - 22s 15ms/step - loss: 3.0430 - accuracy: 0.3803
Epoch 39/100
1480/1480 [==============================] - 22s 15ms/step - loss: 3.0197 - accuracy: 0.3824
Epoch 40/100
1480/1480 [==============================] - 22s 15ms/step - loss: 3.0050 - accuracy: 0.3851
Epoch 41/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.9751 - accuracy: 0.3904
Epoch 42/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.9688 - accuracy: 0.3913
Epoch 43/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.9346 - accuracy: 0.3962
Epoch 44/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.9110 - accuracy: 0.4012
Epoch 45/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.8977 - accuracy: 0.4013
Epoch 46/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.8804 - accuracy: 0.4061
Epoch 47/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.8633 - accuracy: 0.4073
Epoch 48/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.8395 - accuracy: 0.4115
Epoch 49/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.8264 - accuracy: 0.4139
Epoch 50/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.8086 - accuracy: 0.4164
Epoch 51/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.7955 - accuracy: 0.4198
Epoch 52/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.7809 - accuracy: 0.4209
Epoch 53/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.7686 - accuracy: 0.4242
Epoch 54/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.7471 - accuracy: 0.4271
Epoch 55/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.7294 - accuracy: 0.4304
Epoch 56/100
1480/1480 [==============================] - 21s 15ms/step - loss: 2.7173 - accuracy: 0.4324
Epoch 57/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.7288 - accuracy: 0.4312
Epoch 58/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.6924 - accuracy: 0.4354
Epoch 59/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.6754 - accuracy: 0.4402
Epoch 60/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.6699 - accuracy: 0.4394
Epoch 61/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.6519 - accuracy: 0.4439
Epoch 62/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.6439 - accuracy: 0.4441
Epoch 63/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.6315 - accuracy: 0.4471
Epoch 64/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.6133 - accuracy: 0.4494
Epoch 65/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.6082 - accuracy: 0.4509
Epoch 66/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.5930 - accuracy: 0.4528
Epoch 67/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.5828 - accuracy: 0.4551
Epoch 68/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.5670 - accuracy: 0.4570
Epoch 69/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.5577 - accuracy: 0.4601
Epoch 70/100
1480/1480 [==============================] - 21s 15ms/step - loss: 2.5416 - accuracy: 0.4622
Epoch 71/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.5440 - accuracy: 0.4612
Epoch 72/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.5273 - accuracy: 0.4659
Epoch 73/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.5144 - accuracy: 0.4669
Epoch 74/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.5159 - accuracy: 0.4656
Epoch 75/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.4965 - accuracy: 0.4703
Epoch 76/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.4850 - accuracy: 0.4709
Epoch 77/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.4727 - accuracy: 0.4739
Epoch 78/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.4649 - accuracy: 0.4758
Epoch 79/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.4650 - accuracy: 0.4768
Epoch 80/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.4526 - accuracy: 0.4785
Epoch 81/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.4348 - accuracy: 0.4837
Epoch 82/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.4339 - accuracy: 0.4824
Epoch 83/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.4273 - accuracy: 0.4844
Epoch 84/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.4132 - accuracy: 0.4843
Epoch 85/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.4075 - accuracy: 0.4868
Epoch 86/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.4022 - accuracy: 0.4869
Epoch 87/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.3978 - accuracy: 0.4886
Epoch 88/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.3852 - accuracy: 0.4898
Epoch 89/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.3725 - accuracy: 0.4926
Epoch 90/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.3912 - accuracy: 0.4896
Epoch 91/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.3636 - accuracy: 0.4940
Epoch 92/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.3585 - accuracy: 0.4949
Epoch 93/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.3484 - accuracy: 0.4956
Epoch 94/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.3390 - accuracy: 0.4998
Epoch 95/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.3371 - accuracy: 0.4991
Epoch 96/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.3559 - accuracy: 0.4952
Epoch 97/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.3367 - accuracy: 0.5007
Epoch 98/100
1480/1480 [==============================] - 21s 14ms/step - loss: 2.3213 - accuracy: 0.5016
Epoch 99/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.3096 - accuracy: 0.5057
Epoch 100/100
1480/1480 [==============================] - 22s 15ms/step - loss: 2.2980 - accuracy: 0.5071

View the Training Graph

In [7]:
import matplotlib.pyplot as plt

def plot_graphs(history, string):
  plt.plot(history.history[string])
  plt.xlabel("Epochs")
  plt.ylabel(string)
  plt.show()

plot_graphs(history, 'accuracy')
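
The same helper also works for the loss curve, which is worth checking alongside accuracy:

plot_graphs(history, 'loss')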

Generate Better Lyrics!

This time around, we should be able to get a more interesting output with less repetition.

In [8]:
seed_text = "im feeling chills"
next_words = 100
  
for _ in range(next_words):
  # Tokenize the current seed text and pad it to the model's input length
  token_list = tokenizer.texts_to_sequences([seed_text])[0]
  token_list = pad_sequences([token_list], maxlen=max_sequence_len-1, padding='pre')
  # Always take the most likely next word
  predicted = np.argmax(model.predict(token_list), axis=-1)
  output_word = ""
  for word, index in tokenizer.word_index.items():
    if index == predicted:
      output_word = word
      break
  seed_text += " " + output_word
print(seed_text)
im feeling chills me one other time you are here is my life in of love life whole day colour sun kids heavens how turned pride man never bye never never said no way you for me through me and all our last night we used to nancy pleading for you to prays misfortune friends joy man fantasy life heavens pensabamos decision youre reason to dont be alright if you is happy baby ill see of you to cry bright goodnight the old little only way you and me and now i bound you baby that you make me smile and the knees
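
As an aside, the inner loop that scans word_index for the predicted index can be replaced with the tokenizer's reverse lookup, if your Keras version provides it; a minimal sketch:

# tokenizer.index_word (available in recent Keras Tokenizer versions) maps
# token ids straight back to words, replacing the linear search above.
predicted_index = int(np.argmax(model.predict(token_list), axis=-1)[0])
output_word = tokenizer.index_word.get(predicted_index, '')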

Varying the Possible Outputs

In running the above, you may notice that the same seed text generates the same output every time. This is because the code always chooses the top predicted class as the next word. What if you wanted more variance in the output?

Rather than taking the argmax of model.predict, which always picks the single most likely word, we can use the full array of class probabilities that model.predict returns and combine it with np.random.choice to sample the next word according to those probabilities, giving a bit more randomness to our outputs.

In [9]:
# Test the method with just the first word after the seed text
seed_text = "im feeling chills"
next_words = 100
  
token_list = tokenizer.texts_to_sequences([seed_text])[0]
token_list = pad_sequences([token_list], maxlen=max_sequence_len-1, padding='pre')
predicted_probs = model.predict(token_list)[0]
predicted = np.random.choice([x for x in range(len(predicted_probs))], 
                             p=predicted_probs)
# Running this cell multiple times should get you some variance in output
print(predicted)
8
In [10]:
# Use this process for the full output generation
seed_text = "im feeling chills"
next_words = 100
  
for _ in range(next_words):
  token_list = tokenizer.texts_to_sequences([seed_text])[0]
  token_list = pad_sequences([token_list], maxlen=max_sequence_len-1, padding='pre')
  # Sample the next word id from the full predicted distribution instead of taking the argmax
  predicted_probs = model.predict(token_list)[0]
  predicted = np.random.choice([x for x in range(len(predicted_probs))],
                               p=predicted_probs)
  output_word = ""
  for word, index in tokenizer.word_index.items():
    if index == predicted:
      output_word = word
      break
  seed_text += " " + output_word
print(seed_text)
im feeling chills a only way i will calls returning no tan for sun twice in star sight queen trapped quarter seems means distant fucks police losin will fuse runs heal rockin queen headline for citys dumb dumb ground janies mornin diamond twilight babys girlya las burden spiderman pirouette sit wholl give low joy shut sincere takes means distant softly downtown in midnight distant quarter prison roseyeah found age toys life thats headin the says of you loves a that can sometimes that ive not no reason for me than first angels i been givin two way true two is hand life block
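
If you want finer control over how adventurous the sampling is, a common variation (not used in this notebook) is to rescale the predicted probabilities with a temperature before sampling. A minimal sketch, with an arbitrary temperature value:

# Temperature < 1.0 sharpens the distribution (safer, more repetitive output);
# > 1.0 flattens it (more surprising, less coherent). 0.8 is an arbitrary choice.
def sample_with_temperature(probs, temperature=0.8):
  probs = np.asarray(probs, dtype=np.float64)
  scaled = np.exp(np.log(probs + 1e-9) / temperature)
  scaled = scaled / np.sum(scaled)
  return np.random.choice(len(scaled), p=scaled)

# Drop-in replacement for the np.random.choice call above:
# predicted = sample_with_temperature(predicted_probs, temperature=0.8)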